/*
MIT License

Copyright (c) 2022 МГТУ им. Н.Э. Баумана, кафедра ИУ-6, Михаил Фетисов,

https://bmstu.codes/lsx/simodo/loom
*/

#include "ScriptSemantics_abstract.h"
#include "simodo/interpret/AnalyzeException.h"

#include "simodo/variable/FunctionWrapper.h"
#include "simodo/inout/convert/functions.h"
#include "simodo/inout/format/fmt.h"

#include <memory>
#include <cassert>
#include <limits>
#include <cmath>
#include <float.h>

namespace simodo::interpret
{
    /// Applies a unary operator (Not, Minus or Plus) in place to the variable on top of the stack.
    ///
    /// Bool, Int and Float operands are handled directly. Other types are converted first:
    /// to Bool for Not, and to Int (falling back to Float) for Minus/Plus.
    /// Error values and undefined (valueless) operands are left untouched. Plus never changes the value.
    ///
    /// \param opcode   Unary operation to apply.
    /// \param location Source location recorded on the modified operand.
    /// \throws bormental::DrBormental for any opcode other than Not, Minus or Plus.
    void ScriptSemantics_abstract::unary(ScriptOperationCode opcode, const inout::TokenLocation & location)
    {
        // The operand is modified through its origin, directly on the stack.
        variable::Variable & target = stack().variable(stack().top()).origin();

        if (target.value().isError())
            return;

        if (opcode == ScriptOperationCode::Not) {
            if (target.type() == variable::ValueType::Bool) {
                // An undefined operand stays undefined.
                if (target.variant().index() == std::variant_npos)
                    return;
                target.value().variant() = !target.value().getBool(); // -V601
                target.setLocation(location);
                return;
            }

            variable::Variable var = inter().expr().convertVariable(target, variable::ValueType::Bool);

            if (target.variant().index() == std::variant_npos)
                return;

            // NOTE(review): only the variant payload is replaced here; if Value keeps its type tag
            // separately from the variant, the operand keeps its original type — confirm.
            target.value().variant() = !var.value().getBool(); // -V601
            target.setLocation(location);
            return;
        }

        if (opcode != ScriptOperationCode::Minus &&
            opcode != ScriptOperationCode::Plus)
            throw bormental::DrBormental("ScriptSemantics_abstract::unary", inout::fmt("Unsupported"));

        if (target.type() == variable::ValueType::Int) {
            if (target.variant().index() == std::variant_npos)
                return;
            if (opcode == ScriptOperationCode::Minus) {
                target.value().variant() = -(std::get<int64_t>(target.value().variant()));
                target.setLocation(location);
            }
            return;
        }

        if (target.type() == variable::ValueType::Float) {
            if (target.variant().index() == std::variant_npos)
                return;
            if (opcode == ScriptOperationCode::Minus) {
                target.value().variant() = -(std::get<double>(target.value().variant()));
                // Keep the location update consistent with the Int branch.
                target.setLocation(location);
            }
            return;
        }

        // Any other type: try to read the operand as Int first, then as Float.
        variable::Variable var = inter().expr().convertVariable(target, variable::ValueType::Int);

        if (var.type() != variable::ValueType::Int || target.variant().index() == std::variant_npos) {
            var = inter().expr().convertVariable(target, variable::ValueType::Float);
            if (var.type() != variable::ValueType::Float || target.variant().index() == std::variant_npos)
                return;
        }

        if (opcode == ScriptOperationCode::Minus) {
            // The converted value holds int64_t or double depending on which conversion succeeded;
            // reading the wrong alternative would throw std::bad_variant_access.
            // NOTE(review): as above, only the variant payload of the operand is replaced — confirm
            // that the operand's reported type follows the new payload.
            if (var.type() == variable::ValueType::Int)
                target.value().variant() = -(std::get<int64_t>(var.value().variant()));
            else
                target.value().variant() = -(std::get<double>(var.value().variant()));
            target.setLocation(location);
        }
    }

    /// Pops the two topmost operands, combines them with a logical operator and pushes the result.
    ///
    /// Operands that are not Bool are converted to Bool before the operation. If either operand
    /// is an error value, a copy of the error variable is pushed instead.
    ///
    /// \param opcode   Logical operation (see performLogicalOperation).
    /// \param location Source location attached to the pushed result.
    void ScriptSemantics_abstract::logical(ScriptOperationCode opcode, const inout::TokenLocation & location)
    {
        variable::Variable & left  = stack().variable(stack().top(1)).origin();
        variable::Variable & right = stack().variable(stack().top(0)).origin();

        const bool poisoned = left.value().isError() || right.value().isError();
        if (poisoned) {
            stack().pop(2);
            stack().push(variable::error_variable().copyVariable());
            return;
        }

        // Bool operands are used as they are; everything else goes through the Bool conversion.
        auto as_bool = [this](variable::Variable & operand) -> variable::Value {
            if (operand.type() == variable::ValueType::Bool)
                return operand.value();
            return inter().expr().convertVariable(operand, variable::ValueType::Bool).value();
        };

        const variable::Value lhs = as_bool(left);
        const variable::Value rhs = as_bool(right);

        variable::Value outcome = performLogicalOperation(opcode, lhs, rhs);

        stack().pop(2);
        stack().push({{}, outcome, location, {}});
    }

    /// Pops the two topmost operands, compares them and pushes the Bool result.
    ///
    /// Both operands are converted to a common type chosen by getType4TypeConversion.
    /// `>` and `>=` are evaluated as `<` and `<=` with swapped operands, so undefined results
    /// propagate unchanged and NaN follows IEEE rules. `!=` negates only a defined Bool result.
    /// If either operand is an error value, a copy of the error variable is pushed instead.
    ///
    /// \param opcode   Comparison operation.
    /// \param location Source location attached to the pushed result and used in diagnostics.
    /// \throws AnalyzeException if the operand types cannot be compared with this operation.
    void ScriptSemantics_abstract::compare(ScriptOperationCode opcode, const inout::TokenLocation & location)
    {
        variable::Variable & op1_origin = stack().variable(stack().top(1)).origin();
        variable::Variable & op2_origin = stack().variable(stack().top(0)).origin();

        if (op1_origin.value().isError() || op2_origin.value().isError()) {
            stack().pop(2);
            stack().push(variable::error_variable().copyVariable());
            return;
        }

        variable::ValueType target_type = getType4TypeConversion(opcode, op1_origin.type(), op2_origin.type());

        if (target_type == variable::ValueType::Null)
            throw AnalyzeException("ScriptSemantics_abstract::compare", 
                                    location.makeLocation(inter().files()), 
                                    inout::fmt("For operation %1, the use of types %2 and %3 is not provided")
                                        .arg(getSblOperationCodeName(opcode))
                                        .arg(getValueTypeName(op1_origin.type()))
                                        .arg(getValueTypeName(op2_origin.type())));

        variable::Variable  op1 = inter().expr().convertVariable(op1_origin, target_type);
        variable::Variable  op2 = inter().expr().convertVariable(op2_origin, target_type);
        variable::Value     res;

        switch(opcode)
        {
        case ScriptOperationCode::Equal:
            res = performCompareEqual(op1.value(), op2.value());
            break;
        case ScriptOperationCode::NotEqual:
            res = performCompareEqual(op1.value(), op2.value());
            // Negate only a defined Bool: an undefined Bool or an unsupported (Null) result must
            // be passed on unchanged instead of being read as a bool.
            if (res.type() == variable::ValueType::Bool && res.variant().index() != std::variant_npos)
                res = !res.getBool();
            break;
        case ScriptOperationCode::Less:
            res = performCompareLess(op1.value(), op2.value());
            break;
        case ScriptOperationCode::LessOrEqual:
            res = performCompareLessOrEqual(op1.value(), op2.value());
            break;
        case ScriptOperationCode::More:
            // a > b  <=>  b < a  (no negation of a possibly undefined/Null result)
            res = performCompareLess(op2.value(), op1.value());
            break;
        case ScriptOperationCode::MoreOrEqual:
            // a >= b  <=>  b <= a
            res = performCompareLessOrEqual(op2.value(), op1.value());
            break;
        default:
            break;
        }

        if (res.type() == variable::ValueType::Null)
            throw AnalyzeException("ScriptSemantics_abstract::compare", 
                                    location.makeLocation(inter().files()), 
                                    inout::fmt("For operation %1, the use of types %2 and %3 is not provided")
                                        .arg(getSblOperationCodeName(opcode))
                                        .arg(getValueTypeName(op1.type()))
                                        .arg(getValueTypeName(op2.type())));

        stack().pop(2);
        stack().push({{}, res, location, {}});
    }

    /// Pops the two topmost operands, applies a binary arithmetic operator and pushes the result.
    ///
    /// Both operands are converted to the common type chosen by getType4TypeConversion before
    /// performArithmeticOperation is called. If either operand is an error value, a copy of the
    /// error variable is pushed instead.
    ///
    /// \param opcode   Arithmetic operation.
    /// \param location Source location attached to the pushed result and used in diagnostics.
    /// \throws AnalyzeException if the operation is not defined for the operand types.
    void ScriptSemantics_abstract::arithmetic(ScriptOperationCode opcode, const inout::TokenLocation & location)
    {
        variable::Variable & lhs_origin = stack().variable(stack().top(1)).origin();
        variable::Variable & rhs_origin = stack().variable(stack().top(0)).origin();

        const bool has_error = lhs_origin.value().isError() || rhs_origin.value().isError();
        if (has_error) {
            stack().pop(2);
            stack().push(variable::error_variable().copyVariable());
            return;
        }

        // Throws the "unsupported operand types" diagnostic for the given type pair.
        auto reject = [&](variable::ValueType lhs_type, variable::ValueType rhs_type) {
            throw AnalyzeException("ScriptSemantics_abstract::arithmetic", 
                                    location.makeLocation(inter().files()), 
                                    inout::fmt("For operation %1, the use of types %2 and %3 is not provided")
                                        .arg(getSblOperationCodeName(opcode))
                                        .arg(getValueTypeName(lhs_type))
                                        .arg(getValueTypeName(rhs_type)));
        };

        const variable::ValueType common_type = getType4TypeConversion(opcode, lhs_origin.type(), rhs_origin.type());
        if (common_type == variable::ValueType::Null)
            reject(lhs_origin.type(), rhs_origin.type());

        variable::Variable lhs = inter().expr().convertVariable(lhs_origin, common_type);
        variable::Variable rhs = inter().expr().convertVariable(rhs_origin, common_type);

        variable::Value result = performArithmeticOperation(opcode, lhs.value(), rhs.value());
        if (result.type() == variable::ValueType::Null)
            reject(lhs.type(), rhs.type());

        stack().pop(2);
        stack().push({{}, result, location, {}});
    }

    /// Chooses the common type that both operands of a binary operation are converted to.
    ///
    /// The operation code is currently ignored. Rules:
    /// - equal types are kept as they are;
    /// - Int mixed with Float gives Float;
    /// - any other pairing with String gives String;
    /// - anything else gives Null (no conversion possible).
    variable::ValueType ScriptSemantics_abstract::getType4TypeConversion(ScriptOperationCode /*opcode*/, variable::ValueType type1, variable::ValueType type2) const
    {
        using VT = variable::ValueType;

        if (type1 == type2)
            return type1;

        // The types differ here, so "one Int and one Float" means exactly the mixed numeric pair.
        const bool has_int   = (type1 == VT::Int)   || (type2 == VT::Int);
        const bool has_float = (type1 == VT::Float) || (type2 == VT::Float);
        if (has_int && has_float)
            return VT::Float;

        const bool has_string = (type1 == VT::String) || (type2 == VT::String);
        return has_string ? VT::String : VT::Null;
    }

    /// Tells whether a variable's spec marks it as a contract, a structure or a module.
    ///
    /// Looks up the SPEC_ORIGIN entry of the variable's spec object. Only a String-typed origin
    /// equal to one of SPEC_ORIGIN_CONTRACT / SPEC_ORIGIN_STRUCTURE / SPEC_ORIGIN_MODULE qualifies.
    /// NOTE(review): assumes var.spec().object() is never null — confirm.
    bool ScriptSemantics_abstract::maybeContract(const variable::Variable & var) const
    {
        const variable::Variable & origin_var = var.spec().object()->getVariableByName(variable::SPEC_ORIGIN);

        // A check that var itself is a Record used to be part of this condition and is disabled.
        if (origin_var.type() != variable::ValueType::String)
            return false;

        const std::u16string origin = origin_var.value().getString();

        return origin == variable::SPEC_ORIGIN_CONTRACT
            || origin == variable::SPEC_ORIGIN_STRUCTURE
            || origin == variable::SPEC_ORIGIN_MODULE;
    }

    /// Evaluates a logical Or/And on two Bool values.
    ///
    /// \param opcode Or or And; any other code yields a Null value.
    /// \param op1    Left operand (Bool).
    /// \param op2    Right operand (Bool).
    /// \return The Bool result, or a Bool-typed value without a payload when either operand
    ///         is undefined (valueless variant).
    variable::Value ScriptSemantics_abstract::performLogicalOperation(ScriptOperationCode opcode, const variable::Value & op1, const variable::Value & op2) const
    {
        // Both operands must be defined before their payload is read; previously only op1 was
        // checked, so an undefined op2 reached getBool().
        if (op1.variant().index() == std::variant_npos || op2.variant().index() == std::variant_npos)
            return {variable::ValueType::Bool};

        const bool v1 = op1.getBool();
        const bool v2 = op2.getBool();

        switch(opcode)
        {
        case ScriptOperationCode::Or:
            return {bool {v1 || v2}};
        case ScriptOperationCode::And:
            return {bool {v1 && v2}};
        default:
            break;
        }

        return {};
    }

    /// Tests two values of the same type for equality.
    ///
    /// The type of op1 selects the comparison (Bool, Int, Float or String). Floats are compared
    /// with a relative tolerance of a few ULPs.
    /// \return The Bool result. If either operand is undefined, a Bool-typed value without a
    ///         payload is returned. For an unsupported type, a Null value is returned.
    variable::Value ScriptSemantics_abstract::performCompareEqual(const variable::Value & op1, const variable::Value & op2) const
    {
        switch(op1.type())
        {
        case variable::ValueType::Bool:
            if (op1.variant().index() != std::variant_npos && op2.variant().index() != std::variant_npos)
                return op1.getBool() == op2.getBool();
            else
                return variable::ValueType::Bool;
        case variable::ValueType::Int:
            if (op1.variant().index() != std::variant_npos && op2.variant().index() != std::variant_npos)
                return op1.getInt() == op2.getInt();
            else
                return variable::ValueType::Bool;
        case variable::ValueType::Float:
            if (op1.variant().index() != std::variant_npos && op2.variant().index() != std::variant_npos)
            {
                const double val1 = op1.getReal();
                const double val2 = op2.getReal();

                // Exact equality is checked first. It also covers two infinities of the same sign,
                // whose difference is NaN and would fail the tolerance test below.
                if (val1 == val2)
                    return true;

                // Account for floating-point rounding error. Machine epsilon (the gap between 1.0
                // and the next representable double) is scaled to the magnitude of the operands
                // and multiplied by the desired precision in ULPs (units in the last place). The
                // ULP factor should be chosen according to how much error can accumulate.
                // The magnitude 2*|val1 + val2| is computed as 4*|val1/2 + val2/2|. This is
                // bit-identical for normal values, but it cannot overflow to infinity. An overflow
                // would make any two huge values of the same sign compare equal.
                /// \todo Choose an optimal ULP value
                return std::abs(val1 - val2) <= DBL_EPSILON * 4 * std::abs(val1 * 0.5 + val2 * 0.5);
            }
            else
                return variable::ValueType::Bool;
        case variable::ValueType::String:
            if (op1.variant().index() != std::variant_npos && op2.variant().index() != std::variant_npos)
                return op1.getString() == op2.getString();
            else
                return variable::ValueType::Bool;
        default:
            break;
        }

        return {};
    }

    /// Strict "less than" for two values of the same type (Int, Float or String).
    ///
    /// \return The Bool result. If either operand is undefined, a Bool-typed value without a
    ///         payload is returned. For any other type of op1, a Null value is returned; the
    ///         caller reports that as an unsupported operation.
    variable::Value ScriptSemantics_abstract::performCompareLess(const variable::Value & op1, const variable::Value & op2) const
    {
        const variable::ValueType kind = op1.type();

        const bool orderable = kind == variable::ValueType::Int
                            || kind == variable::ValueType::Float
                            || kind == variable::ValueType::String;
        if (!orderable)
            return {};

        const bool undefined = op1.variant().index() == std::variant_npos
                            || op2.variant().index() == std::variant_npos;
        if (undefined)
            return variable::ValueType::Bool;

        if (kind == variable::ValueType::Int)
            return op1.getInt() < op2.getInt();
        if (kind == variable::ValueType::Float)
            return op1.getReal() < op2.getReal();
        return op1.getString() < op2.getString();
    }

    /// "Less than or equal" for two values of the same type (Int, Float or String).
    ///
    /// Floats are compared exactly here, without the tolerance used by performCompareEqual.
    /// \return The Bool result. If either operand is undefined, a Bool-typed value without a
    ///         payload is returned. For any other type of op1, a Null value is returned.
    variable::Value ScriptSemantics_abstract::performCompareLessOrEqual(const variable::Value & op1, const variable::Value & op2) const
    {
        const variable::ValueType kind = op1.type();

        if (kind != variable::ValueType::Int
         && kind != variable::ValueType::Float
         && kind != variable::ValueType::String)
            return {};

        if (op1.variant().index() == std::variant_npos || op2.variant().index() == std::variant_npos)
            return variable::ValueType::Bool;

        if (kind == variable::ValueType::Int)
            return op1.getInt() <= op2.getInt();
        if (kind == variable::ValueType::Float)
            return op1.getReal() <= op2.getReal();
        return op1.getString() <= op2.getString();
    }

    /// Applies a binary arithmetic operator to two operands that are expected to share one type.
    ///
    /// The Assignment* opcodes behave like their plain counterparts. The type of op1 selects the
    /// implementation.
    ///
    /// \return One of:
    ///         - the result of the operation;
    ///         - a value carrying only op1's type (no payload) when either operand is undefined;
    ///         - a Null value when the operator is not defined for the type.
    /// \throws std::overflow_error on division or modulo by zero, and on integer division overflow
    ///         (INT64_MIN / -1).
    /// NOTE(review): Int +, - and * are not overflow-checked (signed overflow is undefined behavior).
    variable::Value ScriptSemantics_abstract::performArithmeticOperation(ScriptOperationCode opcode, const variable::Value & op1, const variable::Value & op2) const
    {
        // An undefined operand makes the result undefined but keeps its type.
        if (op1.variant().index() == std::variant_npos || op2.variant().index() == std::variant_npos)
            return op1.type();

        switch(opcode) 
        {
        case ScriptOperationCode::Addition:
        case ScriptOperationCode::AssignmentAddition:
            switch(op1.type()) {
            case variable::ValueType::Int:
                return op1.getInt() + op2.getInt();
            case variable::ValueType::Float:
                return op1.getReal() + op2.getReal();
            case variable::ValueType::String:
                return op1.getString() + op2.getString();
            default:
                break;
            }
            break;
        case ScriptOperationCode::Subtraction:
        case ScriptOperationCode::AssignmentSubtraction:
            switch(op1.type()) {
            case variable::ValueType::Int:
                return op1.getInt() - op2.getInt();
            case variable::ValueType::Float:
                return op1.getReal() - op2.getReal();
            default:
                break;
            }
            break;
        case ScriptOperationCode::Multiplication:
        case ScriptOperationCode::AssignmentMultiplication:
            switch(op1.type()) {
            case variable::ValueType::Int:
                return op1.getInt() * op2.getInt();
            case variable::ValueType::Float:
                return op1.getReal() * op2.getReal();
            default:
                break;
            }
            break;
        case ScriptOperationCode::Division:
        case ScriptOperationCode::AssignmentDivision:
            switch(op1.type()) {
            case variable::ValueType::Int: {
                    const int64_t dividend = op1.getInt();
                    const int64_t divisor  = op2.getInt();
                    if (divisor == 0)
                        throw std::overflow_error(inout::fmt("Деление на ноль"));
                    // INT64_MIN / -1 does not fit in int64_t: undefined behavior that traps on x86.
                    if (divisor == -1 && dividend == std::numeric_limits<int64_t>::min())
                        throw std::overflow_error(inout::fmt("Integer overflow"));
                    return dividend / divisor;
                }
            case variable::ValueType::Float: {
                    const double divisor = op2.getReal();
                    // Only an exact (signed) zero is a zero divisor. The previous relative-epsilon
                    // test was equivalent for finite values, but it also rejected ±infinity.
                    if (divisor == 0.0)
                        throw std::overflow_error(inout::fmt("Деление на ноль"));
                    return op1.getReal() / divisor;
                }
            default:
                break;
            }
            break;
        case ScriptOperationCode::Modulo:
        case ScriptOperationCode::AssignmentModulo:
            switch(op1.type()) {
            case variable::ValueType::Int:
            {
                const int64_t op1i = op1.getInt();
                const int64_t op2i = op2.getInt();
                if (op2i == 0)
                    throw std::overflow_error(inout::fmt("Деление на ноль"));
                // x % -1 is always 0. Computing INT64_MIN % -1 directly is undefined behavior
                // and traps on x86.
                if (op2i == -1)
                    return int64_t {0};
                return op1i % op2i;
            }
            default:
                break;
            }
            break;
        case ScriptOperationCode::Power:
            switch(op1.type()) {
            /// \todo Exponentiation of integers always returns a floating-point result.
            case variable::ValueType::Int:
                return std::pow(static_cast<double>(op1.getInt()), static_cast<double>(op2.getInt()));
            case variable::ValueType::Float:
                return std::pow(op1.getReal(), op2.getReal());
            default:
                break;
            }
            break;
        default:
            break;
        }

        return {};
    }

}
